{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import time\n",
    "from torch.optim import lr_scheduler\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import random_split\n",
    "\n",
    "from torchvision import transforms, models\n",
    "from MyEDFImports import load_all_labels, stages_names_3_outputs, three_stages_transform\n",
    "from tempfile import TemporaryDirectory"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-12T13:00:37.239903447Z",
     "start_time": "2023-08-12T13:00:12.707601674Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-08-12T13:02:10.201650373Z",
     "start_time": "2023-08-12T13:02:07.124882125Z"
    }
   },
   "outputs": [],
   "source": [
    "# Path to the precomputed .npy file (a single file, despite the variable name);\n",
    "# presumably an image stack of shape (19248, 224, 224), going by the file name.\n",
    "data_dir = '../images_cwt_ssq/images_(19248, 224, 224)_wav_morlet_sqpy.npy'\n",
    "data_np = np.load(data_dir)\n",
    "\n",
    "# Load every label and pass it through three_stages_transform in a single expression.\n",
    "targets = three_stages_transform(load_all_labels())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "class DatasetFromNp(Dataset):\n",
    "    def __init__(self, data, target, transform=None):\n",
    "        self.data = data\n",
    "        self.target = target\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x = self.data[index]\n",
    "        # Extending 1channel image to be put to three channels\n",
    "        x = torch.from_numpy(x)\n",
    "        x = torch.unsqueeze(x, 0)\n",
    "        x = x.expand(3, -1, -1)\n",
    "        y = self.target[index]\n",
    "        if self.transform:\n",
    "            x = self.transform(x)\n",
    "        return x, y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.target)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-11T17:19:56.801350271Z",
     "start_time": "2023-08-11T17:19:56.794344380Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# Input transform: per-channel normalization only, using the usual ImageNet\n",
    "# mean/std. No ToTensor step is needed because DatasetFromNp already returns tensors.\n",
    "data_transforms = transforms.Compose([\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "dataset_all = DatasetFromNp(data_np, target=targets, transform=data_transforms)\n",
    "\n",
    "# 80/20 train/validation split.\n",
    "# NOTE(review): a freshly constructed torch.Generator appears to start from a fixed\n",
    "# default seed, so this split is probably identical on every run - confirm. Call\n",
    "# generator.manual_seed(n) to pin it explicitly or generator.seed() to randomize it.\n",
    "# TODO (original author): remove the fixed generator for real training runs.\n",
    "generator = torch.Generator()  # .manual_seed(22)\n",
    "train_data, test_data = random_split(dataset_all, [0.8, 0.2], generator=generator)\n",
    "datasets = {'train': train_data, 'val': test_data}\n",
    "\n",
    "# Shuffle the training loader only: shuffling the validation subset buys nothing\n",
    "# and a fixed order keeps evaluation repeatable.\n",
    "dataloaders = {\n",
    "    phase: torch.utils.data.DataLoader(\n",
    "        datasets[phase], batch_size=4, shuffle=(phase == 'train'), num_workers=4\n",
    "    )\n",
    "    for phase in ['train', 'val']\n",
    "}\n",
    "dataset_sizes = {phase: len(datasets[phase]) for phase in ['train', 'val']}\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-11T17:20:17.249726406Z",
     "start_time": "2023-08-11T17:20:17.195333862Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Invalid shape (3, 228, 906) for image data",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[6], line 17\u001B[0m\n\u001B[1;32m     15\u001B[0m inputs, classes \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mnext\u001B[39m(\u001B[38;5;28miter\u001B[39m(dataloaders[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrain\u001B[39m\u001B[38;5;124m'\u001B[39m]))\n\u001B[1;32m     16\u001B[0m grid \u001B[38;5;241m=\u001B[39m torchvision\u001B[38;5;241m.\u001B[39mutils\u001B[38;5;241m.\u001B[39mmake_grid(inputs)\n\u001B[0;32m---> 17\u001B[0m \u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43mgrid\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtitle\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mclasses\u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[0;32mIn[6], line 8\u001B[0m, in \u001B[0;36mimshow\u001B[0;34m(inp, title)\u001B[0m\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mimshow\u001B[39m(inp, title\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m):\n\u001B[1;32m      3\u001B[0m     \u001B[38;5;66;03m# inp = inp.numpy().transpose((1, 2, 0))\u001B[39;00m\n\u001B[1;32m      4\u001B[0m     \u001B[38;5;66;03m# mean = np.array([0.485, 0.456, 0.406])\u001B[39;00m\n\u001B[1;32m      5\u001B[0m     \u001B[38;5;66;03m# std = np.array([0.229, 0.224, 0.225])\u001B[39;00m\n\u001B[1;32m      6\u001B[0m     \u001B[38;5;66;03m# inp = std * inp + mean\u001B[39;00m\n\u001B[1;32m      7\u001B[0m     \u001B[38;5;66;03m# inp = np.clip(inp, 0, 1)\u001B[39;00m\n\u001B[0;32m----> 8\u001B[0m     \u001B[43mplt\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43minp\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m      9\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m title \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m     10\u001B[0m         plt\u001B[38;5;241m.\u001B[39mtitle(title)\n",
      "File \u001B[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/matplotlib/pyplot.py:2695\u001B[0m, in \u001B[0;36mimshow\u001B[0;34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001B[0m\n\u001B[1;32m   2689\u001B[0m \u001B[38;5;129m@_copy_docstring_and_deprecators\u001B[39m(Axes\u001B[38;5;241m.\u001B[39mimshow)\n\u001B[1;32m   2690\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mimshow\u001B[39m(\n\u001B[1;32m   2691\u001B[0m         X, cmap\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, norm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m, aspect\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, interpolation\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m   2692\u001B[0m         alpha\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmax\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, origin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, extent\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m   2693\u001B[0m         interpolation_stage\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, filternorm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, filterrad\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4.0\u001B[39m,\n\u001B[1;32m   2694\u001B[0m         resample\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, url\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m-> 2695\u001B[0m     __ret \u001B[38;5;241m=\u001B[39m 
\u001B[43mgca\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m   2696\u001B[0m \u001B[43m        \u001B[49m\u001B[43mX\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcmap\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcmap\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnorm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnorm\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maspect\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43maspect\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2697\u001B[0m \u001B[43m        \u001B[49m\u001B[43minterpolation\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minterpolation\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43malpha\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43malpha\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mvmin\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvmin\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2698\u001B[0m \u001B[43m        \u001B[49m\u001B[43mvmax\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvmax\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43morigin\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43morigin\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mextent\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mextent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2699\u001B[0m \u001B[43m        \u001B[49m\u001B[43minterpolation_stage\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minterpolation_stage\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2700\u001B[0m \u001B[43m        \u001B[49m\u001B[43mfilternorm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfilternorm\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfilterrad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mfilterrad\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mresample\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mresample\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2701\u001B[0m \u001B[43m        \u001B[49m\u001B[43murl\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43murl\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m{\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mdata\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[43m}\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mif\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;129;43;01mis\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;129;43;01mnot\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;28;43;01melse\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43m{\u001B[49m\u001B[43m}\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   2702\u001B[0m \u001B[43m        \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   2703\u001B[0m     sci(__ret)\n\u001B[1;32m   2704\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m __ret\n",
      "File \u001B[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/matplotlib/__init__.py:1442\u001B[0m, in \u001B[0;36m_preprocess_data.<locals>.inner\u001B[0;34m(ax, data, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1439\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[1;32m   1440\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21minner\u001B[39m(ax, \u001B[38;5;241m*\u001B[39margs, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m   1441\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m-> 1442\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43max\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;28;43mmap\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43msanitize_sequence\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43margs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1444\u001B[0m     bound \u001B[38;5;241m=\u001B[39m new_sig\u001B[38;5;241m.\u001B[39mbind(ax, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m   1445\u001B[0m     auto_label \u001B[38;5;241m=\u001B[39m (bound\u001B[38;5;241m.\u001B[39marguments\u001B[38;5;241m.\u001B[39mget(label_namer)\n\u001B[1;32m   1446\u001B[0m                   \u001B[38;5;129;01mor\u001B[39;00m bound\u001B[38;5;241m.\u001B[39mkwargs\u001B[38;5;241m.\u001B[39mget(label_namer))\n",
      "File \u001B[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/matplotlib/axes/_axes.py:5665\u001B[0m, in \u001B[0;36mAxes.imshow\u001B[0;34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001B[0m\n\u001B[1;32m   5657\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mset_aspect(aspect)\n\u001B[1;32m   5658\u001B[0m im \u001B[38;5;241m=\u001B[39m mimage\u001B[38;5;241m.\u001B[39mAxesImage(\u001B[38;5;28mself\u001B[39m, cmap\u001B[38;5;241m=\u001B[39mcmap, norm\u001B[38;5;241m=\u001B[39mnorm,\n\u001B[1;32m   5659\u001B[0m                       interpolation\u001B[38;5;241m=\u001B[39minterpolation, origin\u001B[38;5;241m=\u001B[39morigin,\n\u001B[1;32m   5660\u001B[0m                       extent\u001B[38;5;241m=\u001B[39mextent, filternorm\u001B[38;5;241m=\u001B[39mfilternorm,\n\u001B[1;32m   5661\u001B[0m                       filterrad\u001B[38;5;241m=\u001B[39mfilterrad, resample\u001B[38;5;241m=\u001B[39mresample,\n\u001B[1;32m   5662\u001B[0m                       interpolation_stage\u001B[38;5;241m=\u001B[39minterpolation_stage,\n\u001B[1;32m   5663\u001B[0m                       \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m-> 5665\u001B[0m \u001B[43mim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mset_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   5666\u001B[0m im\u001B[38;5;241m.\u001B[39mset_alpha(alpha)\n\u001B[1;32m   5667\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m im\u001B[38;5;241m.\u001B[39mget_clip_path() \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m   5668\u001B[0m     \u001B[38;5;66;03m# image does not already have clipping set, clip to axes patch\u001B[39;00m\n",
      "File \u001B[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/matplotlib/image.py:710\u001B[0m, in \u001B[0;36m_ImageBase.set_data\u001B[0;34m(self, A)\u001B[0m\n\u001B[1;32m    706\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A[:, :, \u001B[38;5;241m0\u001B[39m]\n\u001B[1;32m    708\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m2\u001B[39m\n\u001B[1;32m    709\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m] \u001B[38;5;129;01min\u001B[39;00m [\u001B[38;5;241m3\u001B[39m, \u001B[38;5;241m4\u001B[39m]):\n\u001B[0;32m--> 710\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInvalid shape \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m for image data\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    711\u001B[0m                     \u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mshape))\n\u001B[1;32m    713\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m:\n\u001B[1;32m    714\u001B[0m     \u001B[38;5;66;03m# If the input data has values outside the valid range (after\u001B[39;00m\n\u001B[1;32m    715\u001B[0m     \u001B[38;5;66;03m# 
normalisation), we issue a warning and then clip X to the bounds\u001B[39;00m\n\u001B[1;32m    716\u001B[0m     \u001B[38;5;66;03m# - otherwise casting wraps extreme values, hiding outliers and\u001B[39;00m\n\u001B[1;32m    717\u001B[0m     \u001B[38;5;66;03m# making reliable interpretation impossible.\u001B[39;00m\n\u001B[1;32m    718\u001B[0m     high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m255\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m np\u001B[38;5;241m.\u001B[39missubdtype(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mdtype, np\u001B[38;5;241m.\u001B[39minteger) \u001B[38;5;28;01melse\u001B[39;00m \u001B[38;5;241m1\u001B[39m\n",
      "\u001B[0;31mTypeError\u001B[0m: Invalid shape (3, 228, 906) for image data"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAa6klEQVR4nO3de2xUZf7H8c+0Q6fIbscIWgvUWlzQKhGXNlTKVqMrNUAwJLuhhg0FFxMbdSt0caF2I0JMGt3IrrfWCxRiUthGBZc/usr8sUK57IVua4xtogG0RVubltAWcQcpz+8P0vk5tmjP0Atf+34l5495PGfmmSd13pwzM63POecEAIAxcaM9AQAAYkHAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACZ5Dtj+/fu1ePFiTZ48WT6fT++8884PHrNv3z5lZmYqMTFR06ZN0yuvvBLLXAEAiPAcsK+++kqzZs3SSy+9NKj9jx8/roULFyo3N1f19fV64oknVFRUpLffftvzZAEA6OO7lF/m6/P5tHv3bi1ZsuSi+6xbt0579uxRU1NTZKywsFAffPCBDh8+HOtDAwDGOP9wP8Dhw4eVl5cXNXbvvfdq69at+uabbzRu3Lh+x4TDYYXD4cjt8+fP6+TJk5o4caJ8Pt9wTxkAMIScc+rp6dHkyZMVFzd0H70Y9oC1tbUpOTk5aiw5OVnnzp1TR0eHUlJS+h1TVlamjRs3DvfUAAAjqKWlRVOnTh2y+xv2gEnqd9bUd9XyYmdTJSUlKi4ujtzu6urSddddp5aWFiUlJQ3fRAEAQ667u1upqan66U9/OqT3O+wBu/baa9XW1hY11t7eLr/fr4kTJw54TCAQUCAQ6DeelJREwADAqKF+C2jYvwc2d+5chUKhqLG9e/cqKytrwPe/AAAYDM8BO336tBoaGtTQ0CDpwsfkGxoa1NzcLOnC5b+CgoLI/oWFhfrss89UXFyspqYmVVZWauvWrVq7du3QPAMAwJjk+RLikSNHdNddd0Vu971XtWLFCm3fvl2tra2RmElSenq6ampqtGbNGr388suaPHmyXnjhBf3qV78agukDAMaqS/oe2Ejp7u5WMBhUV1cX74EBgDHD9RrO70IEAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJMQWsvLxc6enpSkxMVGZmpmpra793/6qqKs2aNUtXXHGFUlJS9MADD6izszOmCQMAIMUQsOrqaq1evVqlpaWqr69Xbm6uFixYoObm5gH3P3DggAoKCrRq1Sp99NFHevPNN/Wf//xHDz744CVPHgAwdnkO2ObNm7Vq1So9+OCDysjI0F/+8helpqaqoqJiwP3/+c9/6vrrr1dRUZHS09P1i1/8Qg899JCOHDlyyZMHAIxdngJ29uxZ1dXVKS8vL2o8Ly9Phw4dGvCYnJwcnThxQjU1NXLO6csvv9Rbb72lRYsWXfRxwuGwuru7ozYAAL7NU8A6OjrU29ur5OTkqPHk5GS1tbUNeExOTo6qqqqUn5+vhIQEXXvttbryyiv14osvXvRxysrKFAwGI1tqaqqXaQIAxoCYPsTh8/mibjvn+o31aWxsVFFRkZ58
8knV1dXp3Xff1fHjx1VYWHjR+y8pKVFXV1dka2lpiWWaAIAfMb+XnSdNmqT4+Ph+Z1vt7e39zsr6lJWVad68eXr88cclSbfeeqsmTJig3NxcPf3000pJSel3TCAQUCAQ8DI1AMAY4+kMLCEhQZmZmQqFQlHjoVBIOTk5Ax5z5swZxcVFP0x8fLykC2duAADEwvMlxOLiYm3ZskWVlZVqamrSmjVr1NzcHLkkWFJSooKCgsj+ixcv1q5du1RRUaFjx47p4MGDKioq0pw5czR58uSheyYAgDHF0yVEScrPz1dnZ6c2bdqk1tZWzZw5UzU1NUpLS5Mktba2Rn0nbOXKlerp6dFLL72k3//+97ryyit1991365lnnhm6ZwEAGHN8zsB1vO7ubgWDQXV1dSkpKWm0pwMA8GC4XsP5XYgAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADAppoCVl5crPT1diYmJyszMVG1t7ffuHw6HVVpaqrS0NAUCAd1www2qrKyMacIAAEiS3+sB1dXVWr16tcrLyzVv3jy9+uqrWrBggRobG3XdddcNeMzSpUv15ZdfauvWrfrZz36m9vZ2nTt37pInDwAYu3zOOeflgOzsbM2ePVsVFRWRsYyMDC1ZskRlZWX99n/33Xd1//3369ixY7rqqqtimmR3d7eCwaC6urqUlJQU030AAEbHcL2Ge7qEePbsWdXV1SkvLy9qPC8vT4cOHRrwmD179igrK0vPPvuspkyZohkzZmjt2rX6+uuvL/o44XBY3d3dURsAAN/m6RJiR0eHent7lZycHDWenJystra2AY85duyYDhw4oMTERO3evVsdHR16+OGHdfLkyYu+D1ZWVqaNGzd6mRoAYIyJ6UMcPp8v6rZzrt9Yn/Pnz8vn86mqqkpz5szRwoULtXnzZm3fvv2iZ2ElJSXq6uqKbC0tLbFMEwDwI+bpDGzSpEmKj4/vd7bV3t7e76ysT0pKiqZMmaJgMBgZy8jIkHNOJ06c0PTp0/sdEwgEFAgEvEwNADDGeDoDS0hIUGZmpkKhUNR4KBRSTk7OgMfMmzdPX3zxhU6fPh0Z+/jjjxUXF6epU6fGMGUAAGK4hFhcXKwtW7aosrJSTU1NWrNmjZqbm1VYWCjpwuW/goKCyP7Lli3TxIkT9cADD6ixsVH79+/X448/rt/+9rcaP3780D0TAMCY4vl7YPn5+ers7NSmTZvU2tqqmTNnqqamRmlpaZKk1tZWNTc3R/b/yU9+olAopN/97nfKysrSxIkTtXTpUj399NND9ywAAGOO5++BjQa+BwYAdl0W3wMDAOByQcAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACA
SQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASTEFrLy8XOnp6UpMTFRmZqZqa2sHddzBgwfl9/t12223xfKwAABEeA5YdXW1Vq9erdLSUtXX1ys3N1cLFixQc3Pz9x7X1dWlgoIC/fKXv4x5sgAA9PE555yXA7KzszV79mxVVFRExjIyMrRkyRKVlZVd9Lj7779f06dPV3x8vN555x01NDRcdN9wOKxwOBy53d3drdTUVHV1dSkpKcnLdAEAo6y7u1vBYHDIX8M9nYGdPXtWdXV1ysvLixrPy8vToUOHLnrctm3bdPToUW3YsGFQj1NWVqZgMBjZUlNTvUwTADAGeApYR0eHent7lZycHDWenJystra2AY/55JNPtH79elVVVcnv9w/qcUpKStTV1RXZWlpavEwTADAGDK4o3+Hz+aJuO+f6jUlSb2+vli1bpo0bN2rGjBmDvv9AIKBAIBDL1AAAY4SngE2aNEnx8fH9zrba29v7nZVJUk9Pj44cOaL6+no9+uijkqTz58/LOSe/36+9e/fq7rvvvoTpAwDGKk+XEBMSEpSZmalQKBQ1HgqFlJOT02//pKQkffjhh2poaIhshYWFuvHGG9XQ0KDs7OxLmz0AYMzyfAmxuLhYy5cvV1ZWlubOnavXXntNzc3NKiwslHTh/avPP/9cb7zxhuLi4jRz5syo46+55holJib2GwcAwAvPAcvPz1dnZ6c2bdqk1tZWzZw5UzU1NUpLS5Mktba2/uB3wgAAuFSevwc2GobrOwQAgOF3WXwPDACAywUBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACbFFLDy8nKlp6crMTFRmZmZqq2tvei+u3bt0vz583X11VcrKSlJc+fO1XvvvRfzhAEAkGIIWHV1tVavXq3S0lLV19crNzdXCxYsUHNz84D779+/X/Pnz1dNTY3q6up01113afHixaqvr7/kyQMAxi6fc855OSA7O1uzZ89WRUVFZCwjI0NLlixRWVnZoO7jlltuUX5+vp588skB/3s4HFY4HI7c7u7uVmpqqrq6upSUlORlugCAUdbd3a1gMDjkr+GezsDOnj2ruro65eXlRY3n5eXp0KFDg7qP8+fPq6enR1ddddVF9ykrK1MwGIxsqampXqYJABgDPAWso6NDvb29Sk5OjhpPTk5WW1vboO7jueee01dffaWlS5dedJ+SkhJ1dXVFtpaWFi/TBACMAf5YDvL5fFG3nXP9xgayc+dOPfXUU/rb3/6ma6655qL7BQIBBQKBWKYGABgjPAVs0qRJio+P73e21d7e3u+s7Luqq6u1atUqvfnmm7rnnnu8zxQAgG/xdAkxISFBmZmZCoVCUeOhUEg5OTkXPW7nzp1auXKlduzYoUWLFsU2UwAAvsXzJcTi4mItX75cWVlZmjt3rl577TU1NzersLBQ0oX3rz7//HO98cYbki7Eq6CgQM8//7xuv/32yNnb+PHjFQwGh/CpAADGEs8By8/PV2dnpzZt2qTW1lbNnDlTNTU1SktLkyS1trZGfSfs1Vdf1blz5/TII4/okUceiYyvWLFC27dvv/RnAAAYkzx/D2w0DNd3CAAAw++y
+B4YAACXCwIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATIopYOXl5UpPT1diYqIyMzNVW1v7vfvv27dPmZmZSkxM1LRp0/TKK6/ENFkAAPp4Dlh1dbVWr16t0tJS1dfXKzc3VwsWLFBzc/OA+x8/flwLFy5Ubm6u6uvr9cQTT6ioqEhvv/32JU8eADB2+ZxzzssB2dnZmj17tioqKiJjGRkZWrJkicrKyvrtv27dOu3Zs0dNTU2RscLCQn3wwQc6fPjwgI8RDocVDocjt7u6unTdddeppaVFSUlJXqYLABhl3d3dSk1N1alTpxQMBofujp0H4XDYxcfHu127dkWNFxUVuTvuuGPAY3Jzc11RUVHU2K5du5zf73dnz54d8JgNGzY4SWxsbGxsP6Lt6NGjXpLzg/zyoKOjQ729vUpOTo4aT05OVltb24DHtLW1Dbj/uXPn1NHRoZSUlH7HlJSUqLi4OHL71KlTSktLU3Nz89DW+0em7185nKl+P9ZpcFinwWGdfljfVbSrrrpqSO/XU8D6+Hy+qNvOuX5jP7T/QON9AoGAAoFAv/FgMMgPyCAkJSWxToPAOg0O6zQ4rNMPi4sb2g++e7q3SZMmKT4+vt/ZVnt7e7+zrD7XXnvtgPv7/X5NnDjR43QBALjAU8ASEhKUmZmpUCgUNR4KhZSTkzPgMXPnzu23/969e5WVlaVx48Z5nC4AABd4Pp8rLi7Wli1bVFlZqaamJq1Zs0bNzc0qLCyUdOH9q4KCgsj+hYWF+uyzz1RcXKympiZVVlZq69atWrt27aAfMxAIaMOGDQNeVsT/Y50Gh3UaHNZpcFinHzZca+T5Y/TShS8yP/vss2ptbdXMmTP15z//WXfccYckaeXKlfr000/1/vvvR/bft2+f1qxZo48++kiTJ0/WunXrIsEDACAWMQUMAIDRxu9CBACYRMAAACYRMACASQQMAGDSZRMw/kTL4HhZp127dmn+/Pm6+uqrlZSUpLlz5+q9994bwdmODq8/S30OHjwov9+v2267bXgneJnwuk7hcFilpaVKS0tTIBDQDTfcoMrKyhGa7ejxuk5VVVWaNWuWrrjiCqWkpOiBBx5QZ2fnCM12dOzfv1+LFy/W5MmT5fP59M477/zgMUPyGj6kv1kxRn/961/duHHj3Ouvv+4aGxvdY4895iZMmOA+++yzAfc/duyYu+KKK9xjjz3mGhsb3euvv+7GjRvn3nrrrRGe+cjyuk6PPfaYe+aZZ9y///1v9/HHH7uSkhI3btw499///neEZz5yvK5Rn1OnTrlp06a5vLw8N2vWrJGZ7CiKZZ3uu+8+l52d7UKhkDt+/Lj717/+5Q4ePDiCsx55XteptrbWxcXFueeff94dO3bM1dbWultuucUtWbJkhGc+smpqalxpaal7++23nSS3e/fu791/qF7DL4uAzZkzxxUWFkaN3XTTTW79+vUD7v+HP/zB3XTTTVFjDz30kLv99tuHbY6XA6/rNJCbb77Zbdy4caindtmIdY3y8/PdH//4R7dhw4YxETCv6/T3v//dBYNB19nZORLTu2x4Xac//elPbtq0aVFjL7zwgps6deqwzfFyM5iADdVr+KhfQjx79qzq6uqUl5cXNZ6Xl6dDhw4NeMzh
w4f77X/vvffqyJEj+uabb4ZtrqMplnX6rvPnz6unp2fIfyP05SLWNdq2bZuOHj2qDRs2DPcULwuxrNOePXuUlZWlZ599VlOmTNGMGTO0du1aff311yMx5VERyzrl5OToxIkTqqmpkXNOX375pd566y0tWrRoJKZsxlC9hsf02+iH0kj9iRbrYlmn73ruuef01VdfaenSpcMxxVEXyxp98sknWr9+vWpra+X3j/r/DiMilnU6duyYDhw4oMTERO3evVsdHR16+OGHdfLkyR/t+2CxrFNOTo6qqqqUn5+v//3vfzp37pzuu+8+vfjiiyMxZTOG6jV81M/A+gz3n2j5sfC6Tn127typp556StXV1brmmmuGa3qXhcGuUW9vr5YtW6aNGzdqxowZIzW9y4aXn6Xz58/L5/OpqqpKc+bM0cKFC7V582Zt3779R30WJnlbp8bGRhUVFenJJ59UXV2d3n33XR0/fpxfnTeAoXgNH/V/cvInWgYnlnXqU11drVWrVunNN9/UPffcM5zTHFVe16inp0dHjhxRfX29Hn30UUkXXqidc/L7/dq7d6/uvvvuEZn7SIrlZyklJUVTpkyJ+oOyGRkZcs7pxIkTmj59+rDOeTTEsk5lZWWaN2+eHn/8cUnSrbfeqgkTJig3N1dPP/30j/LqUCyG6jV81M/A+BMtgxPLOkkXzrxWrlypHTt2/Oivw3tdo6SkJH344YdqaGiIbIWFhbrxxhvV0NCg7OzskZr6iIrlZ2nevHn64osvdPr06cjYxx9/rLi4OE2dOnVY5ztaYlmnM2fO9PujjfHx8ZL+/wwDQ/ga7ukjH8Ok76OqW7dudY2NjW716tVuwoQJ7tNPP3XOObd+/Xq3fPnyyP59H8Fcs2aNa2xsdFu3bh1TH6Mf7Drt2LHD+f1+9/LLL7vW1tbIdurUqdF6CsPO6xp911j5FKLXderp6XFTp051v/71r91HH33k9u3b56ZPn+4efPDB0XoKI8LrOm3bts35/X5XXl7ujh496g4cOOCysrLcnDlzRuspjIienh5XX1/v6uvrnSS3efNmV19fH/m6wXC9hl8WAXPOuZdfftmlpaW5hIQEN3v2bLdv377If1uxYoW78847o/Z///333c9//nOXkJDgrr/+eldRUTHCMx4dXtbpzjvvdJL6bStWrBj5iY8grz9L3zZWAuac93Vqampy99xzjxs/frybOnWqKy4udmfOnBnhWY88r+v0wgsvuJtvvtmNHz/epaSkuN/85jfuxIkTIzzrkfWPf/zje19rhus1nD+nAgAwadTfAwMAIBYEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmPR/vVBObw9VdzEAAAAASUVORK5CYII="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):\n",
    "    \"\"\"Display an image, or a grid of images, with matplotlib.\n",
    "\n",
    "    matplotlib only draws (H, W) or (H, W, 3|4) arrays, whereas the tensors produced\n",
    "    by the data loaders / make_grid are channel-first (3, H, W) and normalized.\n",
    "    Such inputs are moved to channel-last, de-normalized with mean/std and clipped\n",
    "    to [0, 1]. Anything matplotlib can already draw is passed through unchanged.\n",
    "\n",
    "    Args:\n",
    "        inp: torch tensor or NumPy array holding the image.\n",
    "        title: optional figure title.\n",
    "        mean: per-channel mean that the normalization subtracted.\n",
    "        std: per-channel std that the normalization divided by.\n",
    "    \"\"\"\n",
    "    if isinstance(inp, torch.Tensor):\n",
    "        # detach/cpu so tensors that track gradients or live on the GPU also work\n",
    "        inp = inp.detach().cpu().numpy()\n",
    "    channel_first = (\n",
    "        isinstance(inp, np.ndarray)\n",
    "        and inp.ndim == 3\n",
    "        and inp.shape[0] == 3\n",
    "        and inp.shape[-1] not in (3, 4)\n",
    "    )\n",
    "    if channel_first:\n",
    "        inp = inp.transpose((1, 2, 0))  # (3, H, W) -> (H, W, 3)\n",
    "        inp = np.asarray(std) * inp + np.asarray(mean)  # undo Normalize\n",
    "        inp = np.clip(inp, 0, 1)  # float RGB data must lie in [0, 1]\n",
    "    plt.imshow(inp)\n",
    "    if title is not None:\n",
    "        plt.title(title)\n",
    "    plt.pause(0.001)  # pause briefly so the figure gets drawn\n",
    "\n",
    "\n",
    "# Preview one training batch; any other loader would work just as well.\n",
    "inputs, classes = next(iter(dataloaders['train']))\n",
    "grid = torchvision.utils.make_grid(inputs)\n",
    "imshow(grid, title=classes.tolist() if isinstance(classes, torch.Tensor) else classes)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-11T17:20:33.701066641Z",
     "start_time": "2023-08-11T17:20:31.042900821Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
    "    since = time.time()\n",
    "\n",
    "    # Create a temporary directory to save training checkpoints\n",
    "    with TemporaryDirectory() as tempdir:\n",
    "        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')\n",
    "\n",
    "        torch.save(model.state_dict(), best_model_params_path)\n",
    "        best_acc = 0.0\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            print(f'Epoch {epoch + 1}/{num_epochs}')\n",
    "            print('-' * 10)\n",
    "\n",
    "            # Each epoch has a training and validation phase\n",
    "            for phase in ['train', 'val']:\n",
    "                if phase == 'train':\n",
    "                    model.train()  # Set model to training mode\n",
    "                else:\n",
    "                    model.eval()  # Set model to evaluate mode\n",
    "\n",
    "                running_loss = 0.0\n",
    "                running_corrects = 0\n",
    "\n",
    "                # Iterate over data.\n",
    "                for inputs, labels in dataloaders[phase]:\n",
    "                    # No idea why now for inputs why I need to transfer it to a float from a double\n",
    "                    inputs = inputs.to(device, dtype=torch.float)\n",
    "                    labels = labels.to(device)\n",
    "\n",
    "                    # zero the parameter gradients\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                    # forward\n",
    "                    # track history if only in train\n",
    "                    with torch.set_grad_enabled(phase == 'train'):\n",
    "                        outputs = model(inputs)\n",
    "                        _, preds = torch.max(outputs, 1)\n",
    "                        loss = criterion(outputs, labels)\n",
    "\n",
    "                        # backward + optimize only if in training phase\n",
    "                        if phase == 'train':\n",
    "                            loss.backward()\n",
    "                            optimizer.step()\n",
    "\n",
    "                    # statistics\n",
    "                    running_loss += loss.item() * inputs.size(0)\n",
    "                    running_corrects += torch.sum(preds == labels.data)\n",
    "                if phase == 'train':\n",
    "                    scheduler.step()\n",
    "\n",
    "                epoch_loss = running_loss / dataset_sizes[phase]\n",
    "                epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
    "\n",
    "                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')\n",
    "\n",
    "                # deep copy the model\n",
    "                if phase == 'val' and epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    torch.save(model.state_dict(), best_model_params_path)\n",
    "\n",
    "            print()\n",
    "\n",
    "        time_elapsed = time.time() - since\n",
    "        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')\n",
    "        print(f'Best val Acc: {best_acc:4f}')\n",
    "\n",
    "        # load best model weights\n",
    "        model.load_state_dict(torch.load(best_model_params_path))\n",
    "    return model"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-11T17:20:43.592575418Z",
     "start_time": "2023-08-11T17:20:43.576068834Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "ResNet(\n  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n  (relu): ReLU(inplace=True)\n  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n  (layer1): Sequential(\n    (0): Bottleneck(\n      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n      (downsample): Sequential(\n        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      )\n    )\n    (1): Bottleneck(\n      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (2): Bottleneck(\n      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n  )\n  (layer2): Sequential(\n    (0): Bottleneck(\n      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n      (downsample): Sequential(\n        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      )\n    )\n    (1): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (2): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (3): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (4): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (5): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (6): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (7): Bottleneck(\n      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n  )\n  (layer3): Sequential(\n    (0): Bottleneck(\n      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n      (downsample): Sequential(\n        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      )\n    )\n    (1): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (2): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (3): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), 
bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (4): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (5): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (6): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (7): Bottleneck(\n      
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (8): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (9): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (10): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (11): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (12): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (13): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (14): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (15): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (16): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (17): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (18): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (19): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (20): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), 
bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (21): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (22): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (23): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (24): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (25): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (26): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 
1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (27): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (28): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (29): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    
(30): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (31): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (32): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (33): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (34): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (35): Bottleneck(\n      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n  )\n  (layer4): Sequential(\n    (0): Bottleneck(\n      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n      (bn2): 
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n      (downsample): Sequential(\n        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      )\n    )\n    (1): Bottleneck(\n      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n    (2): Bottleneck(\n      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n      (relu): ReLU(inplace=True)\n    )\n  )\n  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n)"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_152 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)\n",
    "num_ftrs = model_152.fc.in_features\n",
    "model_152"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-11T17:21:27.338394445Z",
     "start_time": "2023-08-11T17:21:26.375901400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "model_152.fc = nn.Linear(num_ftrs, 3)\n",
    "model_152.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer_152 = optim.SGD(model_152.parameters(), lr=0.001, momentum=0.9)\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_152, step_size=7, gamma=0.1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T09:34:41.314008856Z",
     "start_time": "2023-06-16T09:34:41.158008144Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/25\n",
      "----------\n",
      "train Loss: 0.8587 Acc: 0.6537\n",
      "val Loss: 0.9394 Acc: 0.6345\n",
      "\n",
      "Epoch 2/25\n",
      "----------\n",
      "train Loss: 0.7049 Acc: 0.7217\n",
      "val Loss: 0.9643 Acc: 0.5310\n",
      "\n",
      "Epoch 3/25\n",
      "----------\n",
      "train Loss: 0.6087 Acc: 0.7604\n",
      "val Loss: 0.9421 Acc: 0.6345\n",
      "\n",
      "Epoch 4/25\n",
      "----------\n",
      "train Loss: 0.5300 Acc: 0.7949\n",
      "val Loss: 0.9875 Acc: 0.6261\n",
      "\n",
      "Epoch 5/25\n",
      "----------\n",
      "train Loss: 0.4850 Acc: 0.8138\n",
      "val Loss: 1.0176 Acc: 0.6339\n",
      "\n",
      "Epoch 6/25\n",
      "----------\n",
      "train Loss: 0.4497 Acc: 0.8245\n",
      "val Loss: 0.9691 Acc: 0.6345\n",
      "\n",
      "Epoch 7/25\n",
      "----------\n",
      "train Loss: 0.4079 Acc: 0.8441\n",
      "val Loss: 0.9163 Acc: 0.6345\n",
      "\n",
      "Epoch 8/25\n",
      "----------\n",
      "train Loss: 0.3011 Acc: 0.8898\n",
      "val Loss: 1.3557 Acc: 0.1839\n",
      "\n",
      "Epoch 9/25\n",
      "----------\n",
      "train Loss: 0.2678 Acc: 0.8989\n",
      "val Loss: 0.8414 Acc: 0.6547\n",
      "\n",
      "Epoch 10/25\n",
      "----------\n",
      "train Loss: 0.2458 Acc: 0.9072\n",
      "val Loss: 0.8368 Acc: 0.6581\n",
      "\n",
      "Epoch 11/25\n",
      "----------\n",
      "train Loss: 0.2295 Acc: 0.9144\n",
      "val Loss: 0.3117 Acc: 0.8844\n",
      "\n",
      "Epoch 12/25\n",
      "----------\n",
      "train Loss: 0.2138 Acc: 0.9193\n",
      "val Loss: 1.9859 Acc: 0.1855\n",
      "\n",
      "Epoch 13/25\n",
      "----------\n",
      "train Loss: 0.1964 Acc: 0.9277\n",
      "val Loss: 0.3111 Acc: 0.8867\n",
      "\n",
      "Epoch 14/25\n",
      "----------\n",
      "train Loss: 0.1808 Acc: 0.9336\n",
      "val Loss: 1.1081 Acc: 0.5931\n",
      "\n",
      "Epoch 15/25\n",
      "----------\n",
      "train Loss: 0.1482 Acc: 0.9468\n",
      "val Loss: 0.3195 Acc: 0.8901\n",
      "\n",
      "Epoch 16/25\n",
      "----------\n",
      "train Loss: 0.1414 Acc: 0.9489\n",
      "val Loss: 0.4777 Acc: 0.8337\n",
      "\n",
      "Epoch 17/25\n",
      "----------\n",
      "train Loss: 0.1416 Acc: 0.9486\n",
      "val Loss: 0.3117 Acc: 0.8896\n",
      "\n",
      "Epoch 18/25\n",
      "----------\n",
      "train Loss: 0.1347 Acc: 0.9535\n",
      "val Loss: 0.3968 Acc: 0.8589\n",
      "\n",
      "Epoch 19/25\n",
      "----------\n",
      "train Loss: 0.1332 Acc: 0.9529\n",
      "val Loss: 0.4839 Acc: 0.8355\n",
      "\n",
      "Epoch 20/25\n",
      "----------\n",
      "train Loss: 0.1292 Acc: 0.9539\n",
      "val Loss: 0.4638 Acc: 0.8407\n",
      "\n",
      "Epoch 21/25\n",
      "----------\n",
      "train Loss: 0.1267 Acc: 0.9551\n",
      "val Loss: 0.8665 Acc: 0.6976\n",
      "\n",
      "Epoch 22/25\n",
      "----------\n",
      "train Loss: 0.1207 Acc: 0.9584\n",
      "val Loss: 0.3124 Acc: 0.8930\n",
      "\n",
      "Epoch 23/25\n",
      "----------\n",
      "train Loss: 0.1206 Acc: 0.9576\n",
      "val Loss: 0.3131 Acc: 0.8883\n",
      "\n",
      "Epoch 24/25\n",
      "----------\n",
      "train Loss: 0.1214 Acc: 0.9579\n",
      "val Loss: 0.3089 Acc: 0.8932\n",
      "\n",
      "Epoch 25/25\n",
      "----------\n",
      "train Loss: 0.1191 Acc: 0.9582\n",
      "val Loss: 0.3261 Acc: 0.8898\n",
      "\n",
      "Training complete in 150m 1s\n",
      "Best val Acc: 0.893219\n"
     ]
    }
   ],
   "source": [
    "model_152 = train_model(model_152, criterion, optimizer_152, exp_lr_scheduler,\n",
    "                        num_epochs=25)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T12:04:42.905120336Z",
     "start_time": "2023-06-16T09:34:41.315141170Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "dir_saved_models = 'saved_models'\n",
    "torch.save(model_152, dir_saved_models + f'/ResNet152_16_06')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-07-11T17:08:15.845417910Z",
     "start_time": "2023-07-11T17:08:15.543067366Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "model_loaded = torch.load('saved_models/ResNet152_16_06')\n",
    "model_loaded.to(device)\n",
    "model_loaded.eval()\n",
    "dl_for_all = torch.utils.data.DataLoader(dataset_all, batch_size=1, num_workers=4)\n",
    "predictions = []\n",
    "with torch.no_grad():\n",
    "    for i, l in dl_for_all:\n",
    "        i = i.to(device, dtype=torch.float32)\n",
    "        l.to(device)\n",
    "        all_outputs = model_loaded(i)\n",
    "        predictions.append(all_outputs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-07-11T17:13:09.008735349Z",
     "start_time": "2023-07-11T17:08:16.944623845Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "model_loaded.cpu()\n",
    "del model_loaded\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T15:42:53.982435329Z",
     "start_time": "2023-06-16T15:42:53.757373364Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "only one element tensors can be converted to Python scalars",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mValueError\u001B[0m                                Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[11], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtensor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpredictions\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mValueError\u001B[0m: only one element tensors can be converted to Python scalars"
     ]
    }
   ],
   "source": [
    "torch.tensor(predictions)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:01:26.762157551Z",
     "start_time": "2023-06-16T16:01:26.725691981Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "data": {
      "text/plain": "19248"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(predictions)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T15:54:02.047667158Z",
     "start_time": "2023-06-16T15:54:02.016976474Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "prrrred = [torch.max(prediction, 1)[1] for prediction in predictions]\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:06:32.496453506Z",
     "start_time": "2023-06-16T16:06:32.299392149Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "data": {
      "text/plain": "False"
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prrrred == targets"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:06:35.273641843Z",
     "start_time": "2023-06-16T16:06:35.266507137Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": "0.47604945968412304"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "good = 0\n",
    "for i in range(len(targets)):\n",
    "    if targets[i] == prrrred[i]:\n",
    "        good += 1\n",
    "good / len(targets)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-06-16T16:08:22.114824871Z",
     "start_time": "2023-06-16T16:08:21.633745182Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
